""" DiffGro Planner Implementation """
from functools import partial
from typing import Any, Tuple, List, Dict, Union, Type, Optional, Callable

import time
import gym
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
import wandb

from sb3_jax.common.offline_algorithm import OfflineAlgorithm
from sb3_jax.common.buffers import BaseBuffer
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.jax_utils import jax_print, jit_optimize, stop_grad

from diffgro.utils.utils import print_b
from diffgro.common.buffers import TrajectoryBuffer
from diffgro.common.models.utils import kl_div
from diffgro.diffgro.policies import DiffGroPlannerPolicy, apply_mask


class DiffGroPlanner(OfflineAlgorithm):
    """Offline trainer for a ``DiffGroPlannerPolicy``.

    Two sub-modules of the policy are optimised in alternation on batches
    sampled from the replay buffer:

    * ``policy.act`` -- the planner, trained with a diffusion loss plus a
      ``beta``-weighted KL regulariser (see ``_pi_loss``).
    * ``policy.pri`` -- the prior, trained with a KL term against the planner's
      (gradient-stopped) posterior statistics (see ``_pr_loss``).
    """

    def __init__(
        self,
        policy: Union[str, Type[DiffGroPlannerPolicy]],
        env: Union[GymEnv, str],
        replay_buffer: Type[BaseBuffer] = TrajectoryBuffer,
        learning_rate: float = 3e-4,
        batch_size: int = 256,
        gamma: float = 0.99,
        gradient_steps: int = 1,
        beta: float = 1e-4,  # kl regularization weight
        wandb_log: Optional[Dict[str, Any]] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = 1,
        _init_setup_model: bool = True,
    ):
        """Configure the algorithm and (optionally) build the policy.

        Args:
            policy: Policy class (or registered name) to instantiate.
            env: Environment (or its id) forwarded to the base class.
            replay_buffer: Buffer forwarded to the base class.
            learning_rate: Learning rate handed to the policy.
            batch_size: Batch size used for training and for building the
                inpainting mask in ``_setup_model``.
            gamma: Discount factor forwarded to the base class.
            gradient_steps: Number of gradient steps per ``train`` call.
            beta: Weight of the KL regulariser in the planner loss.
            wandb_log: Optional wandb configuration; ``learn`` reads its
                ``'tag'`` entry, so a mapping is expected.
            policy_kwargs: Extra keyword arguments for the policy constructor.
            verbose: Verbosity level forwarded to the base class.
            seed: Random seed.
            _init_setup_model: Whether to call ``_setup_model`` immediately.
        """
        super().__init__(
            policy,
            env,
            replay_buffer=replay_buffer,
            learning_rate=learning_rate,
            batch_size=batch_size,
            gamma=gamma,
            gradient_steps=gradient_steps,
            tensorboard_log=None,
            wandb_log=wandb_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            create_eval_env=False,
            seed=seed,
            # Model setup is deferred so that it runs after the attributes
            # below have been assigned.
            _init_setup_model=False,
            # Trailing comma matters: without it this is just a parenthesised
            # class rather than a one-element tuple.
            supported_action_spaces=(gym.spaces.Box,),
            support_multi_env=False,
        )
        self.learning_rate = learning_rate
        self.beta = beta

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Instantiate the policy and precompute the inpainting mask.

        The resulting ``self.mask`` has shape
        ``[batch_size * horizon, horizon, obs_dim + act_dim]``.

        NOTE(review): the mask is sized with ``self.batch_size`` while ``train``
        accepts its own ``batch_size`` argument; the two presumably have to
        match for ``apply_mask`` to work -- confirm.
        """
        self.set_random_seed(self.seed)
        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.learning_rate,
            seed=self.seed,
            **self.policy_kwargs,
        )
        self._create_aliases()

        horizon = self.act.horizon
        ones = jnp.ones((horizon, horizon))

        # Observation mask: lower triangular including the diagonal,
        # e.g. [[1,0,0],[1,1,0],[1,1,1]].
        mask_obs = jnp.tril(ones)
        # [horizon, horizon, 1] -> [batch * horizon, horizon, 1]
        mask_obs = jnp.tile(jnp.expand_dims(mask_obs, -1), (self.batch_size, 1, 1))
        mask_obs = jnp.repeat(mask_obs, self.act.obs_dim, axis=-1)

        # Action mask: strictly lower triangular (the observation mask shifted
        # down by one row), e.g. [[0,0,0],[1,0,0],[1,1,0]].
        mask_act = jnp.tril(ones, k=-1)
        mask_act = jnp.tile(jnp.expand_dims(mask_act, -1), (self.batch_size, 1, 1))
        mask_act = jnp.repeat(mask_act, self.act.act_dim, axis=-1)

        self.mask = jnp.concatenate((mask_obs, mask_act), axis=-1)

    def _create_aliases(self) -> None:
        """Expose the policy's planner (``act``) and prior (``pri``) on self."""
        self.act = self.policy.act
        self.pri = self.policy.pri

    def train(self, gradient_steps: int, batch_size: int = 256) -> None:
        """Run ``gradient_steps`` planner + prior updates and log the losses.

        Args:
            gradient_steps: Number of update iterations to perform.
            batch_size: Number of trajectories sampled per iteration.
        """
        pi_losses, kl_losses, du_losses, pr_losses = [], [], [], []
        # Stays None when gradient_steps == 0 so the reconstruction logging
        # below can be skipped instead of failing on undefined names.
        pi_info = None

        for _ in range(gradient_steps):
            self._n_updates += 1
            batch_keys = ['tasks', 'observations', 'actions', 'skills']
            replay_data = self.replay_buffer.sample(batch_keys, batch_size, max_length=self.act.horizon)
            task = replay_data.tasks
            obs = replay_data.observations
            act = replay_data.actions
            skill = replay_data.skills

            # Preprocess observations as a flat [N, obs_dim] batch, then
            # restore the per-trajectory layout.
            p_obs = self.policy.preprocess(obs.reshape(-1, self.act.obs_dim), training=True)
            p_obs = p_obs.reshape(batch_size, -1, self.act.obs_dim)

            # 1. learning encoder & decoder (planner)
            self.act.optim_state, self.act.params, pi_loss, pi_info = jit_optimize(
                self._pi_loss,
                self.act.optim,
                self.act.optim_state,
                self.act.params,
                max_grad_norm=None,
                obs=p_obs,
                act=act,
                rng=next(self.policy.rng)
            )
            pi_losses.append(pi_loss)
            kl_losses.append(pi_info["kl_loss"])
            du_losses.append(pi_info["du_loss"])

            # 2. learning prior (posterior statistics are treated as constants)
            self.pri.optim_state, self.pri.params, pr_loss, pr_info = jit_optimize(
                self._pr_loss,
                self.pri.optim,
                self.pri.optim_state,
                self.pri.params,
                max_grad_norm=None,
                obs=p_obs,
                task=task,
                skill=skill,
                q_mean=stop_grad(pi_info["q_mean"]),
                q_std=stop_grad(pi_info["q_std"]),
                rng=next(self.policy.rng)
            )
            pr_losses.append(pr_loss)

        wandb_log = {"time/total_timesteps": self.num_timesteps}
        if pi_info is not None and self._n_updates % self.log_interval == 0:
            # Reconstruction MSE on the last sampled batch: denoise
            # deterministically and compare the action slice with the data.
            pred_traj, _ = self.act._denoise_act(
                pi_info["cond"], self.mask, pi_info["q_mean"], skill, None, deterministic=True)
            target_act = jnp.repeat(act, self.act.horizon, axis=0)
            action_mse = jnp.mean(jnp.square(target_act - pred_traj[:, :, -self.act.act_dim:]))
            self.logger.record("train/pi/recon_mse", action_mse)
            wandb_log.update({"train/pi/recon_mse": action_mse})

        # Compute each mean once and reuse it for both logging back-ends.
        mean_losses = {
            "train/pi/loss": np.mean(pi_losses),
            "train/pi/loss_kl": np.mean(kl_losses),
            "train/pi/loss_du": np.mean(du_losses),
            "train/pr/loss": np.mean(pr_losses),
        }
        self.logger.record("train/batch_size", batch_size)
        for key, value in mean_losses.items():
            self.logger.record(key, value)
        wandb_log.update(mean_losses)
        if self.wandb_log is not None:
            wandb.log(wandb_log)

    @partial(jax.jit, static_argnums=0)
    def _pi_loss(
        self,
        pi_params: hk.Params,
        obs: jax.Array,
        act: jax.Array,
        rng=None,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Planner loss: diffusion loss + ``beta`` * KL regulariser.

        ``self`` is a static jit argument, so attributes read here
        (``self.mask``, ``self.beta``, ``self.act`` settings) are taken at
        trace time.

        Args:
            pi_params: Planner parameters being optimised.
            obs: Preprocessed observations, indexed as ``[batch, time, obs_dim]``.
            act: Actions; concatenated with ``obs`` on the last axis.
            rng: JAX PRNG key.

        Returns:
            ``(loss, info)`` where ``info`` carries ``kl_loss``, ``du_loss``,
            the posterior ``q_mean`` / ``q_std`` and the conditioning
            trajectory ``cond``.
        """
        batch_size = obs.shape[0]
        # The split count is kept at 4 so the derived keys stay the same; the
        # fourth key is not used.
        rng_t, rng_n, rng_p, _ = jax.random.split(rng, num=4)

        # 1. forward process
        horizon = self.act.horizon
        noise = jax.random.normal(rng_n, shape=(batch_size * horizon, horizon, self.act.out_dim))
        traj = jnp.concatenate((obs, act), axis=-1)
        traj = jnp.repeat(traj, horizon, axis=0)  # one copy per mask row
        # One diffusion timestep per trajectory, shared by all of its copies.
        ts = jax.random.randint(rng_t, shape=(batch_size, 1), minval=1, maxval=self.act.n_denoise+1)
        ts = jnp.repeat(ts, horizon, axis=0)
        sqrtab = jnp.repeat(self.act.ddpm_dict.sqrtab[ts], horizon, axis=0).reshape(batch_size * horizon, -1, 1)
        sqrtmab = jnp.repeat(self.act.ddpm_dict.sqrtmab[ts], horizon, axis=0).reshape(batch_size * horizon, -1, 1)
        x_t = sqrtab * traj + sqrtmab * noise

        # 2. condition with inpainting
        # presumably apply_mask overwrites masked entries of x_t with `cond`
        # -- confirm against diffgro.diffgro.policies.apply_mask
        cond = traj
        x_t = apply_mask(self.mask, x_t, cond)

        # 3. diffusion inference
        enc_batch_dict = {"obs": obs, "act": act}
        dec_batch_dict = {}
        batch_dict = {"enc": enc_batch_dict, "dec": dec_batch_dict}
        (mean, std, noise_pred), info = self.act._pi(x_t, batch_dict, ts, None, False, False, pi_params, rng_p)

        # 4. loss calculation
        # regularization loss
        kl_loss = jnp.mean(kl_div(mean, std))
        # diffusion loss; the same mask is applied to prediction and target
        noise_pred = apply_mask(self.mask, noise_pred, cond)
        if self.act.predict_epsilon:  # prediction of noise
            noise = apply_mask(self.mask, noise, cond)
            du_loss = jnp.mean(jnp.square(noise_pred - noise))
        else:  # prediction of original traj
            du_loss = jnp.mean(jnp.square(noise_pred - traj))
        loss = du_loss + self.beta * kl_loss
        return loss, {"kl_loss": kl_loss, "du_loss": du_loss, "q_mean": mean, "q_std": std, "cond": cond}

    @partial(jax.jit, static_argnums=0)
    def _pr_loss(
        self,
        pr_params: hk.Params,
        obs: jax.Array,
        task: jax.Array,
        skill: jax.Array,
        q_mean: jax.Array,
        q_std: jax.Array,
        rng=None,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Prior loss: mean KL term between the prior and the posterior stats.

        Args:
            pr_params: Prior parameters being optimised.
            obs: Preprocessed observations; only the first time step is used.
            task: Task input for the prior.
            skill: Skill input for the prior.
            q_mean: Posterior mean produced by the planner.
            q_std: Posterior standard deviation produced by the planner.
            rng: JAX PRNG key forwarded to the prior.

        Returns:
            ``(kl_loss, info)`` where ``info`` holds ``p_mean`` and ``p_std``.
        """
        obs = obs[:, 0, :]  # retrieve s_1

        # 1. prior inference
        batch_dict = {"obs": obs, "task": task, "skill": skill}
        p_mean, p_std = self.pri._pr(batch_dict, pr_params, rng)

        # 2. loss calculation
        kl_loss = jnp.mean(kl_div(p_mean, p_std, q_mean, q_std))

        return kl_loss, {"p_mean": p_mean, "p_std": p_std}

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "DiffGroPlanner",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "DiffGroPlanner":
        """Run the offline training loop.

        Each iteration calls ``train`` once and counts as one timestep.

        Args:
            total_timesteps: Number of training iterations.
            callback: Callback(s) invoked after every iteration; a callback
                returning ``False`` stops training early.
            log_interval: Dump logs every ``log_interval`` iterations
                (``None`` disables the periodic dump).
            eval_env, eval_freq, n_eval_episodes, eval_log_path,
            reset_num_timesteps, tb_log_name: Forwarded to ``_setup_learn``.

        Returns:
            ``self``, including when a callback stops training early.
        """
        self.log_interval = log_interval

        # wandb configs
        if self.wandb_log is not None:
            self.wandb_config = dict(
                time=time.ctime(),
                algo='diffgro/planner',
                tag=self.wandb_log['tag'],
                learning_rate=self.learning_rate,
                batch_size=self.batch_size,
                gamma=self.gamma,
                gradient_steps=self.gradient_steps,
                seed=self.seed,
            )
            self.wandb_config.update(self.policy._get_constructor_parameters())

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )
        callback.on_training_start(locals(), globals())

        # Local names are left as-is: they are exposed to callbacks through
        # `locals()`.
        start_time = time.time()
        num_timesteps = 0
        while num_timesteps < total_timesteps:
            self.train(gradient_steps=self.gradient_steps, batch_size=self.batch_size)

            self.num_timesteps += 1
            num_timesteps += 1
            if log_interval is not None and num_timesteps % log_interval == 0:
                time_elapsed = time.time() - start_time
                # Clamp the divisor: the clock may not have advanced yet.
                fps = int(num_timesteps / max(time_elapsed, 1e-8))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                self.logger.record("time/total_timesteps", num_timesteps, exclude="tensorboard")
                self.logger.dump(step=num_timesteps)

            callback.update_locals(locals())
            if callback.on_step() is False:
                # Early stop requested: leave the loop so that
                # on_training_end still runs and `self` is returned, as the
                # return annotation promises.
                break

        callback.on_training_end()
        return self

    def load_params(self, path: str) -> None:
        """Load planner/prior parameters and the normalization layer from ``path``."""
        # NOTE(review): `load_from_zip_file` was called here without ever being
        # imported in this module (NameError at call time). The import path
        # assumes sb3_jax mirrors stable_baselines3's `common.save_util`
        # layout -- confirm against the installed sb3_jax.
        from sb3_jax.common.save_util import load_from_zip_file

        print_b("[diffgro] : loading params")
        _, params = load_from_zip_file(path, verbose=1)
        self._load_jax_params(params)
        self._load_norm_layer(path)

    def _save_jax_params(self) -> Dict[str, hk.Params]:
        """Return the planner and prior parameters keyed for serialization."""
        return {
            "pi_params": self.act.params,
            "pr_params": self.pri.params,
        }

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Hand the full parameter dict to the planner and to the prior."""
        self.act._load_jax_params(params)
        self.pri._load_jax_params(params)

    def _save_norm_layer(self, path: str) -> None:
        """Save the policy's normalization layer, if one is configured."""
        if self.policy.normalization_class is not None:
            self.policy.normalization_layer.save(path)

    def _load_norm_layer(self, path: str) -> None:
        """Replace the policy's normalization layer with the one stored at ``path``."""
        if self.policy.normalization_class is not None:
            self.policy.normalization_layer = self.policy.normalization_layer.load(path)
